use std::{collections::HashMap, i32, ops::Bound};

use diesel::{
    pg::{sql_types, Pg},
    query_builder::{AstPass, QueryFragment, QueryId},
    sql_query,
    sql_types::{Binary, Bool, Integer, Jsonb, Nullable},
    ExpressionMethods, OptionalExtension, QueryDsl, QueryResult,
};
use diesel_async::RunQueryDsl;

use graph::{
    anyhow::{anyhow, Context},
    components::store::{write, StoredDynamicDataSource},
    data_source::CausalityRegion,
    internal_error,
    prelude::{serde_json, BlockNumber, StoreError},
};

use crate::{primary::Namespace, relational_queries::POSTGRES_MAX_PARAMETERS, AsyncPgConnection};

/// A table that is addressed by a runtime `String` name inside a `Namespace` schema.
type DynTable = diesel_dynamic_schema::Table<String, Namespace>;
/// A column of a [`DynTable`] with a static name and the SQL type `ST`.
type DynColumn<ST> = diesel_dynamic_schema::Column<DynTable, &'static str, ST>;

/// Handle for the `data_sources$` table of one namespace, together with the
/// columns that the queries in this file select, filter or order by.
#[derive(Debug)]
pub(crate) struct DataSourcesTable {
    /// The namespace (schema) that contains the table.
    namespace: Namespace,
    /// The schema-qualified table name `<namespace>.data_sources$`, used in hand-written SQL.
    qname: String,
    /// The table as seen by the dynamic query builder.
    table: DynTable,
    /// `vid integer`, the identity primary key; used to order rows.
    vid: DynColumn<Integer>,
    /// `block_range int4range not null`.
    block_range: DynColumn<sql_types::Range<Integer>>,
    /// `causality_region integer not null`.
    causality_region: DynColumn<Integer>,
    /// `manifest_idx integer not null`.
    manifest_idx: DynColumn<Integer>,
    /// `param bytea`, nullable.
    param: DynColumn<Nullable<Binary>>,
    /// `context jsonb`, nullable.
    context: DynColumn<Nullable<Jsonb>>,
    /// `done_at int`, nullable.
    done_at: DynColumn<Nullable<Integer>>,
}

impl DataSourcesTable {
    const TABLE_NAME: &'static str = "data_sources$";

    pub(crate) fn new(namespace: Namespace) -> Self {
        let table =
            diesel_dynamic_schema::schema(namespace.clone()).table(Self::TABLE_NAME.to_string());

        DataSourcesTable {
            qname: format!("{}.{}", namespace, Self::TABLE_NAME),
            namespace,
            vid: table.column("vid"),
            block_range: table.column("block_range"),
            causality_region: table.column("causality_region"),
            manifest_idx: table.column("manifest_idx"),
            param: table.column("param"),
            context: table.column("context"),
            done_at: table.column("done_at"),
            table,
        }
    }

    /// Return the DDL that creates the `data_sources$` table in this namespace, together with
    /// a gist index on `block_range` and a btree index on `causality_region`.
    ///
    /// Every mention of the table, including the index names, is derived from
    /// [`Self::TABLE_NAME`] so that the statements cannot drift apart.
    pub(crate) fn as_ddl(&self) -> String {
        format!(
            "
            create table {nsp}.{table} (
                vid integer generated by default as identity primary key,
                block_range int4range not null,
                causality_region integer not null,
                manifest_idx integer not null,
                parent integer references {nsp}.{table},
                id bytea,
                param bytea,
                context jsonb,
                done_at int
            );

            create index gist_block_range_{table} on {nsp}.{table} using gist (block_range);
            create index btree_causality_region_{table} on {nsp}.{table} (causality_region);
            ",
            nsp = self.namespace,
            table = Self::TABLE_NAME
        )
    }

    // Query to load the data sources which are live at `block`. Ordering by the creation block and
    // `vid` makes sure they are in insertion order which is important for the correctness of
    // reverts and the execution order of triggers. See also 8f1bca33-d3b7-4035-affc-fd6161a12448.
    pub(super) async fn load(
        &self,
        conn: &mut AsyncPgConnection,
        block: BlockNumber,
    ) -> Result<Vec<StoredDynamicDataSource>, StoreError> {
        type Tuple = (
            (Bound<i32>, Bound<i32>),
            i32,
            Option<Vec<u8>>,
            Option<serde_json::Value>,
            CausalityRegion,
            Option<i32>,
        );
        let tuples = self
            .table
            .clone()
            .filter(diesel::dsl::sql::<Bool>("block_range @> ").bind::<Integer, _>(block))
            .select((
                &self.block_range,
                &self.manifest_idx,
                &self.param,
                &self.context,
                &self.causality_region,
                &self.done_at,
            ))
            .order_by(&self.vid)
            .load::<Tuple>(conn)
            .await?;

        let mut dses: Vec<_> = tuples
            .into_iter()
            .map(
                |(block_range, manifest_idx, param, context, causality_region, done_at)| {
                    let creation_block = match block_range.0 {
                        Bound::Included(block) => Some(block),

                        // Should never happen.
                        Bound::Excluded(_) | Bound::Unbounded => {
                            unreachable!("dds with open creation")
                        }
                    };

                    StoredDynamicDataSource {
                        manifest_idx: manifest_idx as u32,
                        param: param.map(|p| p.into()),
                        context,
                        creation_block,
                        done_at,
                        causality_region,
                    }
                },
            )
            .collect();

        // This sort is stable and `tuples` was ordered by vid, so `dses` will be ordered by `(creation_block, vid)`.
        dses.sort_by_key(|v| v.creation_block);

        Ok(dses)
    }

    /// Insert every data source in `data_sources` with a block range that starts at its
    /// `creation_block` and is unbounded above.
    ///
    /// Returns the total number of rows inserted.
    ///
    /// # Errors
    /// Returns an error if an onchain data source has a `creation_block` that differs from the
    /// block it is listed under, if a manifest index does not fit into the `integer` column,
    /// or if an insert statement fails. Rows inserted before the failure are not undone here.
    pub(crate) async fn insert(
        &self,
        conn: &mut AsyncPgConnection,
        data_sources: &write::DataSources,
    ) -> Result<usize, StoreError> {
        // The statement text is identical for every row, so build it once.
        //
        // Offchain data sources have a unique causality region assigned from a sequence in the
        // database, while onchain data sources always have causality region 0.
        let query = format!(
            "insert into {}(block_range, manifest_idx, param, context, causality_region, done_at) \
             values (int4range($1, null), $2, $3, $4, $5, $6)",
            self.qname
        );

        let mut inserted_total = 0;

        for (block_ptr, dss) in &data_sources.entries {
            let block = block_ptr.number;
            for ds in dss {
                let StoredDynamicDataSource {
                    manifest_idx,
                    param,
                    context,
                    creation_block,
                    done_at,
                    causality_region,
                } = ds;

                // Nested offchain data sources might not pass this check, as their `creation_block`
                // will be their parent's `creation_block`, not necessarily `block`.
                if causality_region == &CausalityRegion::ONCHAIN && creation_block != &Some(block) {
                    return Err(internal_error!(
                        "mismatching creation blocks `{:?}` and `{}`",
                        creation_block,
                        block
                    ));
                }

                // The column is a signed integer; refuse values that a plain `as` cast would
                // silently wrap into negative numbers.
                let manifest_idx = i32::try_from(*manifest_idx).map_err(|_| {
                    internal_error!(
                        "manifest index `{}` does not fit into an integer column",
                        manifest_idx
                    )
                })?;

                let insert = sql_query(query.as_str())
                    .bind::<Nullable<Integer>, _>(creation_block)
                    .bind::<Integer, _>(manifest_idx)
                    .bind::<Nullable<Binary>, _>(param.as_ref().map(|p| &**p))
                    .bind::<Nullable<Jsonb>, _>(context)
                    .bind::<Integer, _>(causality_region)
                    .bind::<Nullable<Integer>, _>(done_at);

                inserted_total += insert.execute(conn).await?;
            }
        }
        Ok(inserted_total)
    }

    /// Delete all data sources whose block range starts at `block` or later.
    pub(crate) async fn revert(
        &self,
        conn: &mut AsyncPgConnection,
        block: BlockNumber,
    ) -> Result<(), StoreError> {
        // `&>` is the 'does not extend to the left of' operator. It can use the gist index and
        // is equivalent to `lower(block_range) >= $1`.
        //
        // Nothing is unclamped here: this assumes that all ranges are of the form [x, +inf),
        // and therefore no range ends inside the reverted blocks.
        let delete_stmt = format!(
            "delete from {} where block_range &> int4range($1, null)",
            self.qname
        );
        let delete = sql_query(delete_stmt).bind::<Integer, _>(block);
        delete.execute(conn).await?;
        Ok(())
    }

    /// Copy the dynamic data sources from `self` to `dst`. Every data source whose block range
    /// starts at or before `target_block` is copied, with its manifest index translated from
    /// the source's numbering to the destination's by template name.
    ///
    /// If `dst` already contains rows, nothing is copied and the number of existing rows is
    /// returned; otherwise the number of rows inserted is returned.
    pub(crate) async fn copy_to(
        &self,
        conn: &mut AsyncPgConnection,
        dst: &DataSourcesTable,
        target_block: BlockNumber,
        src_manifest_idx_and_name: &[(i32, String)],
        dst_manifest_idx_and_name: &[(i32, String)],
    ) -> Result<usize, StoreError> {
        // Rows in the destination indicate that an earlier copy already ran
        let existing = dst.table.clone().count().get_result::<i64>(conn).await?;
        if existing > 0 {
            return Ok(existing as usize);
        }

        let idx_map = ManifestIdxMap::new(src_manifest_idx_and_name, dst_manifest_idx_and_name);

        // Fetch everything that was created up to and including `target_block`, in `vid` order
        let created_by_target =
            diesel::dsl::sql::<Bool>("lower(block_range) <= ").bind::<Integer, _>(target_block);
        let src_rows = self
            .table
            .clone()
            .filter(created_by_target)
            .select((
                &self.block_range,
                &self.manifest_idx,
                &self.param,
                &self.context,
                &self.causality_region,
                &self.done_at,
            ))
            .order_by(&self.vid)
            .load::<DsForCopy>(conn)
            .await?;

        // Rewrite each row for the destination; the first row that can not be translated
        // aborts the copy before anything is written
        let mut rows = Vec::with_capacity(src_rows.len());
        for row in src_rows {
            rows.push(row.src_to_dst(target_block, &idx_map, &self.namespace, &dst.namespace)?);
        }

        // Insert in batches small enough that one statement never needs more than
        // `POSTGRES_MAX_PARAMETERS` bind variables
        let rows_per_stmt = POSTGRES_MAX_PARAMETERS / CopyDsQuery::BIND_PARAMS;
        let mut copied = 0;
        for batch in rows.chunks(rows_per_stmt) {
            copied += CopyDsQuery::new(dst, batch)?.execute(conn).await?;
        }

        // With identical manifest indexes on both sides, source and destination must load the
        // same data sources at `target_block`. Only checked in debug builds.
        if src_manifest_idx_and_name == dst_manifest_idx_and_name {
            debug_assert!(
                self.load(conn, target_block)
                    .await
                    .map_err(|e| e.to_string())
                    == dst
                        .load(conn, target_block)
                        .await
                        .map_err(|e| e.to_string())
            );
        }

        Ok(copied)
    }

    /// Record the `done_at` value of the given data sources in the table.
    ///
    /// Rows are not removed. Each data source in `data_sources` is matched to its row by
    /// causality region, which currently uniquely identifies an offchain data source, and the
    /// row's `done_at` is set to the data source's `done_at`.
    ///
    /// # Errors
    /// Returns an error if an update fails or if a single causality region matched more than
    /// one row. The offending update has already been executed when that error is returned.
    pub(super) async fn update_offchain_status(
        &self,
        conn: &mut AsyncPgConnection,
        data_sources: &write::DataSources,
    ) -> Result<(), StoreError> {
        // The statement text is identical for every data source, so build it once.
        let query = format!(
            "update {} set done_at = $1 where causality_region = $2",
            self.qname
        );

        for (_, dss) in &data_sources.entries {
            for ds in dss {
                let count = sql_query(query.as_str())
                    .bind::<Nullable<Integer>, _>(ds.done_at)
                    .bind::<Integer, _>(ds.causality_region)
                    .execute(conn)
                    .await?;

                // NOTE: the message speaks of removal although this function only updates
                // `done_at`; it is kept verbatim so existing log searches keep matching.
                if count > 1 {
                    return Err(internal_error!(
                        "expected to remove at most one offchain data source but would remove {}, causality region: {}",
                        count,
                        ds.causality_region
                    ));
                }
            }
        }

        Ok(())
    }

    /// The current causality sequence according to the store, which is inferred to be the
    /// maximum `causality_region` existing in the table. Returns `None` for an empty table.
    pub(super) async fn causality_region_curr_val(
        &self,
        conn: &mut AsyncPgConnection,
    ) -> Result<Option<CausalityRegion>, StoreError> {
        // Sorting descending and taking the first row yields the maximum while leveraging the
        // btree index on `causality_region`.
        let max_region = self
            .table
            .clone()
            .select(&self.causality_region)
            .order_by((&self.causality_region).desc())
            .first::<CausalityRegion>(conn)
            .await
            .optional()?;

        Ok(max_region)
    }
}

/// Map src manifest indexes to dst manifest indexes. Entries are matched by
/// name. If the destination is missing an entry, `None` is stored as the
/// value for the source index.
struct ManifestIdxMap {
    /// Source index -> (destination index if the destination has an entry
    /// with the same name, name of the source entry)
    map: HashMap<i32, (Option<i32>, String)>,
}

impl ManifestIdxMap {
    /// Build the map from `(index, name)` pairs of the source and the
    /// destination. A source entry whose name does not appear in `dst` is
    /// recorded with `None` as its destination index.
    fn new(src: &[(i32, String)], dst: &[(i32, String)]) -> Self {
        let dst_by_name: HashMap<&str, i32> = dst
            .iter()
            .map(|(idx, name)| (name.as_str(), *idx))
            .collect();

        let mut map = HashMap::with_capacity(src.len());
        for (src_idx, src_name) in src {
            let dst_idx = dst_by_name.get(src_name.as_str()).copied();
            map.insert(*src_idx, (dst_idx, src_name.clone()));
        }

        ManifestIdxMap { map }
    }

    /// Translate `src_idx` into the corresponding destination index.
    ///
    /// `src_nsp`, `src_created` and `dst_nsp` are only used for error
    /// messages. It is an error if the source has no entry for `src_idx`
    /// or if the destination has no entry with the same name.
    fn dst_idx(
        &self,
        src_idx: i32,
        src_nsp: &Namespace,
        src_created: BlockNumber,
        dst_nsp: &Namespace,
    ) -> Result<i32, StoreError> {
        // First failure mode: the source index itself is unknown
        let entry = self.map.get(&src_idx).with_context(|| {
            anyhow!(
                "the source {src_nsp} does not have a template with \
                 index {src_idx} but created one at block {src_created}"
            )
        })?;
        let (dst_idx, name) = entry;

        // Second failure mode: the name has no counterpart in the destination
        let translated = dst_idx.with_context(|| {
            anyhow!(
                "the destination {dst_nsp} is missing a template with \
                 name {name}. The source {src_nsp} created one at block {src_created}"
            )
        })?;

        Ok(translated)
    }
}

/// One row of the data sources table as it is read in `copy_to` and, after
/// `src_to_dst` has been applied, written to the destination.
///
/// The fields are in the same order as the columns that `copy_to` selects
/// (`block_range`, `manifest_idx`, `param`, `context`, `causality_region`,
/// `done_at`).
#[derive(Queryable)]
struct DsForCopy {
    /// Lower and upper bound of the `block_range` column
    block_range: (Bound<i32>, Bound<i32>),
    /// The `manifest_idx` column; `src_to_dst` replaces it with the
    /// destination's index
    idx: i32,
    /// The `param` column
    param: Option<Vec<u8>>,
    /// The `context` column
    context: Option<serde_json::Value>,
    /// The `causality_region` column
    causality_region: i32,
    /// The `done_at` column
    done_at: Option<i32>,
}

impl DsForCopy {
    /// Turn a row read from the source into the row to insert into the
    /// destination: a block range that ends beyond `target_block` becomes
    /// open-ended, and the manifest index is translated through `map`.
    ///
    /// `src_nsp` and `dst_nsp` are passed on to `map` for its error
    /// messages. Fails if `map` can not translate the index.
    fn src_to_dst(
        mut self,
        target_block: BlockNumber,
        map: &ManifestIdxMap,
        src_nsp: &Namespace,
        dst_nsp: &Namespace,
    ) -> Result<Self, StoreError> {
        // Does the range end beyond the target block? An exclusive upper
        // bound `b` means the last block in the range is `b - 1`
        let ends_beyond_target = match self.block_range.1 {
            Bound::Included(end) => end > target_block,
            Bound::Excluded(end) => end - 1 > target_block,
            Bound::Unbounded => false,
        };
        if ends_beyond_target {
            // unclamp; otherwise the range is used as is
            self.block_range.1 = Bound::Unbounded;
        }

        // The first block of the range, only needed for error messages
        let created_at = match self.block_range.0 {
            Bound::Included(start) => start,
            Bound::Excluded(start) => start + 1,
            Bound::Unbounded => 0,
        };

        // Translate manifest index
        self.idx = map.dst_idx(self.idx, src_nsp, created_at, dst_nsp)?;
        Ok(self)
    }
}

/// A multi-row `insert` of `dss` into the table `dst`. The SQL is generated
/// by the `QueryFragment` implementation.
struct CopyDsQuery<'a> {
    /// The table to insert into
    dst: &'a DataSourcesTable,
    /// The rows to insert, already translated for the destination
    dss: &'a [DsForCopy],
}

impl<'a> CopyDsQuery<'a> {
    /// The number of bind parameters the generated statement uses for each
    /// row in `dss`
    const BIND_PARAMS: usize = 6;

    /// Create a query that inserts `dss` into `dst`. This currently never
    /// fails; the `Result` is part of the signature callers use.
    fn new(dst: &'a DataSourcesTable, dss: &'a [DsForCopy]) -> Result<Self, StoreError> {
        Ok(CopyDsQuery { dst, dss })
    }
}

impl<'a> QueryFragment<Pg> for CopyDsQuery<'a> {
    /// Generate
    /// `insert into <dst>(<columns>) values ($1, .., $6), ($7, .., $12), ...`
    /// with one parenthesized group of bind parameters per entry in `dss`.
    fn walk_ast<'b>(&'b self, mut out: AstPass<'_, 'b, Pg>) -> QueryResult<()> {
        // The statement text depends on the number of rows, so it is marked
        // as not safe to cache as a prepared statement
        out.unsafe_to_cache_prepared();

        out.push_sql("insert into ");
        out.push_sql(&self.dst.qname);
        out.push_sql(
            "(block_range, manifest_idx, param, context, causality_region, done_at) values ",
        );

        for (pos, ds) in self.dss.iter().enumerate() {
            // Separate row groups with commas
            if pos > 0 {
                out.push_sql(", ");
            }

            out.push_sql("(");
            out.push_bind_param::<sql_types::Range<Integer>, _>(&ds.block_range)?;
            out.push_sql(", ");
            out.push_bind_param::<Integer, _>(&ds.idx)?;
            out.push_sql(", ");
            out.push_bind_param::<Nullable<Binary>, _>(&ds.param)?;
            out.push_sql(", ");
            out.push_bind_param::<Nullable<Jsonb>, _>(&ds.context)?;
            out.push_sql(", ");
            out.push_bind_param::<Integer, _>(&ds.causality_region)?;
            out.push_sql(", ");
            out.push_bind_param::<Nullable<Integer>, _>(&ds.done_at)?;
            out.push_sql(")");
        }

        Ok(())
    }
}

// `CopyDsQuery` declares that it has no static query id; the SQL it
// generates varies with the number of rows being inserted.
impl<'a> QueryId for CopyDsQuery<'a> {
    type QueryId = ();

    const HAS_STATIC_QUERY_ID: bool = false;
}
